from jax import numpy as jnp
from jax import random

# Compare how jnp.inner treats a 1-D vector versus a (5, 1) column matrix,
# and contrast the column case with an explicit outer-style product.

key = random.PRNGKey(0)

# NOTE(review): both draws use the very same PRNG key, so `mat` presumably
# holds the same numbers as `vec`, only laid out as a column. If independent
# samples were intended, derive separate keys with `random.split` first.
vec = random.normal(key, (5,))
mat = random.normal(key, (5, 1))

# 1-D with 1-D: an ordinary inner product, giving a 0-d (scalar) array.
inner_vec = jnp.inner(vec, vec)
# For N-D inputs jnp.inner contracts the LAST axis of each operand, so two
# (5, 1) columns produce a (5, 5) array -- the same thing as mat @ mat.T,
# not a scalar.
inner_mat = jnp.inner(mat, mat)
dot_mat = mat @ mat.T

# Report results in the original order with the original labels.
for label, value in (
    ('vec', vec.shape),
    ('mat', mat.shape),
    ('inner vec', inner_vec),
    ('inner mat', inner_mat),
    ('dot mat', dot_mat),
):
    print(label, value)
